## Overview
This Notebook takes a KQL query and breaks it into batches that fit within the limits of the Azure Monitor API. This allows us to export more than the default 30,000 record/64MB limits experienced when using the native interface. The export will run the batches in parallel and write the data to local disk in the format specified in the OUTPUT_FORMAT parameter.

## 1. Install Dependencies
Run this cell to install the required Python libraries.

In [None]:
import sys
!{sys.executable} -m pip install azure-monitor-query azure-identity pandas tqdm

## 2. Set Parameters
Modify the parameters below as necessary, then run the cell.

In [None]:
from datetime import datetime, timedelta, timezone

#Required parameters:
START_TIME = datetime(2024, 12, 10, tzinfo=timezone.utc) #Start of the time range for the query (timezone-aware, UTC).
END_TIME = datetime(2024, 12, 15, tzinfo=timezone.utc) #End of the time range for the query (timezone-aware, UTC).
QUERY = "SecurityEvent | project TimeGenerated, Account" #KQL query to run.

#If needed, change which Log Analytics workspace to use:
USE_DEFAULT_LAW_ID = True #If True, use the Log Analytics workspace ID read from the config.json file (created by Sentinel Notebooks) whenever that file provides one.
LAW_ID = "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" #Log Analytics workspace ID to use if config.json file is not present or USE_DEFAULT_LAW_ID is set to False.

#Optional parameters used for performance and output tuning:
THREADS = 2 #Number of jobs to run in parallel. Typically, this should match the number of cores of the VM. Because the Azure Monitor API can only run 5 concurrent queries at a time, there are diminishing returns after a certain point.
AUTO_BATCH = True #Attempts to automatically detect the optimal batch size (time range) to use when breaking up the query.
BATCH_SIZE = timedelta(hours=6) #If AUTO_BATCH is set to False, this batch size (time range) will be used to break up the query.
MIN_BATCH_SIZE = timedelta(minutes=1) #If the data returned cannot fit within this time range, we skip and move to the next batch.
OUTPUT_DIRECTORY = "./law_export" #Directory where results will be stored. A new directory gets created for each run.
OUTPUT_FILE_PREFIX = "query_results" #Prefix used for the data files containing the query results.
OUTPUT_FORMAT = 'CSV' #File format used to store the query results on disk. CSV or PARQUET values are supported.
OUTPUT_COMBINE_FILES = True #Combine all job data files into a single file.
TIMEOUT = 3 #Number of minutes allowed before query times out. 10 minutes is max. Also used as the number of minutes without any progress update before the export stops waiting on the jobs.

## 3. Export Data
Run the below cell to start the export process. Data will be written to local files in the directory specified in the OUTPUT_DIRECTORY parameter.

In [None]:
from datetime import datetime, timedelta, timezone
import pandas as pd
import time
from azure.monitor.query import LogsQueryClient, LogsQueryStatus
from azure.core.exceptions import HttpResponseError
from azure.identity import DefaultAzureCredential
import logging
import os
import glob 
import json
from multiprocessing import Pool, Manager
from tqdm import tqdm

class time_range_class:
    """Named time window describing the slice of the query one job handles.

    Instances are plain attribute holders exposing ``name``, ``start_time``
    and ``end_time`` exactly as they were passed to the constructor.
    """

    def __init__(self, name, start_time, end_time):
        # Bind all three constructor arguments in a single assignment.
        self.name, self.start_time, self.end_time = name, start_time, end_time

def get_time_ranges(start_time=None, end_time=None, number_of_ranges=5):
    """Split ``[start_time, end_time]`` into equal, non-overlapping ranges.

    The ranges are returned newest first: ``"Job 0"`` ends at ``end_time``
    and the last job starts at ``start_time``. Every range except the newest
    ends one microsecond before the range after it begins, so two jobs never
    query the same instant.

    Args:
        start_time: Start of the overall window. Defaults to 24 hours before
            ``end_time``.
        end_time: End of the overall window. Defaults to the current time,
            resolved at call time (using ``start_time``'s tzinfo when given).
        number_of_ranges: How many ranges to produce; must be at least 1.

    Returns:
        A list of ``time_range_class`` objects, one per range.

    Raises:
        ValueError: If ``number_of_ranges`` is smaller than 1.
    """
    if number_of_ranges < 1:
        raise ValueError("number_of_ranges must be at least 1.")

    # Resolve defaults per call: a `datetime.now()` default in the signature
    # would be evaluated only once, when the function is defined.
    if end_time is None:
        tz = start_time.tzinfo if start_time is not None else None
        end_time = datetime.now(tz)
    if start_time is None:
        start_time = end_time - timedelta(hours=24)

    interval = (end_time - start_time) / number_of_ranges
    one_microsecond = timedelta(microseconds=1)

    ranges = []
    for index in range(number_of_ranges):
        range_start = end_time - (index + 1) * interval
        range_end = end_time - index * interval
        if index > 0:
            # Keep this range from overlapping the newer one before it.
            range_end -= one_microsecond
        if index == number_of_ranges - 1:
            # timedelta division rounds to whole microseconds; pin the oldest
            # boundary so the full requested window is always covered.
            range_start = start_time
        ranges.append(time_range_class(f"Job {index}", range_start, range_end))

    return ranges

def read_config_values(file_path):
    """Return the ``workspace_id`` entry of a JSON config file.

    This is a best-effort lookup: any problem reading or interpreting the
    file yields ``None`` so the caller can fall back to another source.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The value stored under the ``"workspace_id"`` key, or ``None`` when
        the file cannot be opened, is not valid JSON, or has no such key.
    """
    try:
        with open(file_path, encoding="utf-8") as json_file:
            json_config = json.load(json_file)
        return json_config["workspace_id"]
    except (OSError, ValueError, KeyError, TypeError):
        # OSError: missing/unreadable file. ValueError: invalid JSON or bad
        # encoding. KeyError/TypeError: key absent or top level not a dict.
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        return None

def write_to_file(df, export_path, prefix, format):
    match format:
        case 'PARQUET':
            path = os.path.join(export_path, f"{prefix}.parquet")
            df.to_parquet(path)
        case 'CSV':
            path = os.path.join(export_path, f"{prefix}.csv")
            df.to_csv(path, index=False)    
    
def get_batch_size(query, law_id, start_time, end_time):
    """Estimate how many batches are needed to export ``query`` over a window.

    Appends a summarizing pipeline to ``query`` that computes a batch count
    from two targets written into the KQL below: 38,400,000 bytes per batch
    (using the average ``estimate_data_size(*)`` per row) and 450,000 rows
    per batch. The larger of the two counts wins and is floored at 1.

    Uses the module-level ``client`` to run the query.

    Args:
        query: The user's KQL query.
        law_id: Log Analytics workspace ID to query.
        start_time: Start of the time window.
        end_time: End of the time window.

    Returns:
        The estimated number of batches as a float, never below 1.0. When the
        window holds no data (no rows, or a null/NaN estimate) 1.0 is
        returned so the caller can run the window as a single batch.

    Raises:
        Exception: If the query does not complete successfully; carries the
            service's inner error detail.
    """
    # Join with newlines so a trailing `//` line comment in the user's query
    # cannot swallow the appended operators.
    batch_query = "\n".join([
        f"{query}",
        "| summarize NumberOfBatchesBytes = 38400000 / avg(estimate_data_size(*)), NumberOfBatchesRows = count()",
        "| project NumberOfBatchesBytes = todecimal(NumberOfBatchesRows / NumberOfBatchesBytes), NumberOfBatchesRows = todecimal(NumberOfBatchesRows) / todecimal(450000)",
        "| project NumberOfBatches = round(iff(NumberOfBatchesBytes > NumberOfBatchesRows, NumberOfBatchesBytes, NumberOfBatchesRows), 2)",
        "| project NumberOfBatches = iif(NumberOfBatches < toreal(1), toreal(1), NumberOfBatches)",
    ])

    response = client.query_workspace(workspace_id=law_id, query=batch_query, timespan=(start_time, end_time))

    if response.status != LogsQueryStatus.SUCCESS:
        error = response.partial_error
        raise Exception(error.details[0]["innererror"])

    estimate = None
    for table in response.tables:
        df = pd.DataFrame(data=table.rows, columns=table.columns)
        if not df.empty:
            estimate = df['NumberOfBatches'].iloc[0]

    # No table, no rows, or a null average (empty window) all mean the whole
    # window fits in one batch.
    if estimate is None or pd.isna(estimate):
        return 1.0

    # float() guarantees the value can divide a timedelta in the caller.
    return max(float(estimate), 1.0)

def export_log_analytics_data(
    law_id: str,
    query: str,
    start_time: datetime = None,
    end_time: datetime = None,
    batch_size: timedelta = timedelta(hours=4),
    job_name: str = None,
    queue = None,
    min_batch_size: timedelta = timedelta(minutes=15),
    client: LogsQueryClient = None,
    export_path = '',
    export_prefix = 'query_results',
    auto_batch = True,
    export_format: str = 'CSV',
    timeout: int = 10,
    delay: int = 0,
    max_retries: int = 5,
    export_to_file: bool = True,
    json_depth: int = 10,
    ):
    """Export one job's time range from Log Analytics in adaptive batches.

    Works backwards from ``end_time`` to ``start_time``, querying one batch
    at a time and writing each returned table to disk with ``write_to_file``.
    When the service reports that a result exceeds its size/record limits the
    batch is halved (down to ``min_batch_size``, after which the batch is
    skipped); after more than five consecutive successful batches a reduced
    batch is doubled again while it is below the ``batch_size`` passed in.

    Args:
        law_id: Log Analytics workspace ID.
        query: KQL query to run for every batch.
        start_time: Start of the job's time range (required in practice).
        end_time: End of the job's time range (required in practice).
        batch_size: Initial batch width; replaced by an estimate from
            ``get_batch_size`` when ``auto_batch`` is true.
        job_name: Name used in log messages, progress updates and as the log
            file name.
        queue: Queue receiving progress dictionaries with the keys
            ``job_name``, ``progress_update`` and ``rows_returned``. Progress
            reporting is skipped when it is ``None``.
        min_batch_size: Smallest batch width tried before a batch is skipped.
        client: ``LogsQueryClient`` used to run the batch queries.
        export_path: Directory receiving the data files and the job log.
        export_prefix: Prefix for the data file names.
        auto_batch: Whether to estimate the batch size before exporting.
        export_format: Format handed to ``write_to_file``.
        timeout: Server-side query timeout in minutes.
        delay: Seconds to sleep after each successful batch.
        max_retries: Number of unhandled errors tolerated before giving up.
        export_to_file: Currently unused.
        json_depth: Currently unused.

    Returns:
        ``{'job_name', 'status': 'success', 'rows_returned_total'}`` on
        success, or ``{'job_name', 'status': 'error'}`` when the batch size
        estimate fails or the retry budget is exhausted.
    """
    time_range: timedelta = end_time - start_time
    error_count: int = 0
    initial_batch_size: timedelta = batch_size
    batch_count: timedelta = timedelta()
    percent_complete: int = 0
    stop_time: datetime = start_time
    time_format: str = "%m-%d-%Y %H-%M-%S"
    runs_without_error_count: int = 0
    loop_done: bool = False
    is_final_batch: bool = False
    rows_returned: int = 0
    # Error texts that mean "the result was too large for one response".
    size_limit_messages = (
        "Maximum response size of 100000000 bytes exceeded",
        'The results of this query exceed the set limit of 64000000 bytes',
        'The results of this query exceed the set limit of 500000 records',
    )

    def report_progress(progress_update, rows):
        # Progress reporting is optional so the function also works without
        # a consumer on the other side.
        if queue is not None:
            queue.put({'job_name': job_name, 'progress_update': progress_update, 'rows_returned': rows})

    # force=True replaces handlers left on the root logger by an earlier job
    # in the same worker process (or inherited from the parent); without it
    # basicConfig is a no-op and this job would log into the wrong place.
    logging.basicConfig(filename=f"{export_path}/{job_name}.log",
        filemode='a',
        format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
        datefmt='%H:%M:%S',
        level=logging.INFO,
        force=True)

    logging.info(f"{job_name}: Starting new job.")

    if auto_batch:
        try:
            batch_size = time_range / get_batch_size(query, law_id, start_time, end_time)
        except Exception as err:
            logging.error(f"{job_name}: Unhandled Error: {type(err)} {err}")
            return ({'job_name': job_name, 'status': 'error'})

    if batch_size > time_range: batch_size = time_range

    while error_count <= max_retries:
        try:
            while not loop_done:

                if batch_size < initial_batch_size and runs_without_error_count > 5:
                    batch_size *= 2
                    logging.info(f"{job_name}: Increasing batch size to {batch_size}.")

                start_time = end_time - batch_size

                # The batch reaching stop_time is the last one; clamp it so
                # the job never reads before its own range.
                is_final_batch = start_time <= stop_time
                if is_final_batch:
                    start_time = stop_time
                    batch_size = end_time - start_time

                logging.info(f"{job_name}: Running query between {start_time.strftime(time_format)} and {end_time.strftime(time_format)}.")

                # `timeout` is expressed in minutes; the SDK's server-side
                # limit is `server_timeout`, in seconds.
                response = client.query_workspace(workspace_id=law_id, query=query, timespan=(start_time, end_time), server_timeout=timeout * 60)

                if response.status != LogsQueryStatus.SUCCESS:
                    error = response.partial_error
                    raise Exception(error.details[0]["innererror"])

                batch_rows = 0
                for table_index, table in enumerate(response.tables):
                    df = pd.DataFrame(data=table.rows, columns=table.columns)
                    file_prefix = f"{export_prefix}_{start_time.strftime(time_format)}"
                    if table_index > 0:
                        # Additional tables get their own file instead of
                        # overwriting the first one.
                        file_prefix = f"{file_prefix}_{table_index}"
                    write_to_file(df, export_path, file_prefix, export_format)
                    batch_rows += int(df.shape[0])

                batch_count += batch_size
                percent_complete_previous = percent_complete
                # A zero-width range has nothing to divide by; it is complete.
                percent_complete = round((batch_count / time_range) * 100) if time_range else 100
                logging.info(f"{job_name}: Received {batch_rows} rows of data and written to disk. Percent Complete: {percent_complete}")
                report_progress(percent_complete - percent_complete_previous, batch_rows)
                rows_returned += batch_rows

                runs_without_error_count += 1
                end_time = start_time + timedelta(microseconds=-1)
                # Mark the job finished only once the final batch succeeded,
                # so a failed final batch is retried instead of being dropped.
                loop_done = is_final_batch
                time.sleep(delay)

            logging.info(f"{job_name}: Finished exporting {rows_returned} records from Log Analytics. Percent Complete: 100")
            report_progress(100 - percent_complete, 0)

            return ({'job_name': job_name, 'status': 'success', 'rows_returned_total': rows_returned})
        except Exception as err:
            message = str(err)
            if "Response ended prematurely" in message:
                # Transient transport failure: retry the same batch.
                logging.warning(f"{job_name}: Response ended prematurely, retrying. Message {type(err)} {err}")
            elif any(limit in message for limit in size_limit_messages):
                runs_without_error_count = 0
                # `<=` also covers a clamped final batch that is already
                # narrower than the minimum; `==` would keep resizing it
                # forever.
                if batch_size <= min_batch_size:
                    logging.error(f"{job_name}: Results cannot be returned in the specified minimum batch size. Skipping batch between {start_time.strftime(time_format)} and {end_time.strftime(time_format)}. Message: {type(err)} {err}")
                    batch_count += batch_size
                    end_time = start_time + timedelta(microseconds=-1)
                    # Skipping the final batch leaves nothing more to query.
                    loop_done = is_final_batch
                else:
                    batch_size = max(batch_size / 2, min_batch_size)
                    logging.info(f"{job_name}: Reduced batch size to: {batch_size}. Message: {type(err)} {err}")
            else:
                logging.error(f"Unhandled Error: {type(err)} {err}")
                error_count += 1
                if error_count > max_retries:
                    logging.error(f"{job_name}: Max number of retries reached, exiting.")
                    return ({'job_name': job_name, 'status': 'error'})

    # Defensive fallback; the retry branch above normally returns first.
    return ({'job_name': job_name, 'status': 'error'})
  

# Bookkeeping shared with the pool callbacks and the progress loop.
completed_jobs, failed_jobs = [], []
last_queue_time = datetime.now()

# Every run writes into its own timestamped folder below OUTPUT_DIRECTORY.
time_format: str = "%m-%d-%Y %H-%M-%S"
os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)
job_directory = f"{OUTPUT_DIRECTORY}/{datetime.now().strftime(time_format)}"
os.mkdir(job_directory)

# One time range per parallel job.
ranges = get_time_ranges(start_time=START_TIME, end_time=END_TIME, number_of_ranges=THREADS )

def log_result(result):
    """Pool callback: file a finished job's result dict by its status.

    Results whose ``status`` is ``'success'`` go to ``completed_jobs``; any
    other result is announced on stdout and added to ``failed_jobs``.
    """
    if result['status'] != 'success':
        print(f"{result['job_name']} has failed. Please check log file for details.")
        failed_jobs.append(result)
        return
    completed_jobs.append(result)

def log_error(error):
    """Pool error callback: report a job that raised and count it as failed.

    Args:
        error: The exception raised inside the worker.
    """
    print(error)
    # A job that raises never reaches log_result, so record it here; the
    # monitor loop counts entries in failed_jobs to know when every job has
    # finished and would otherwise wait for the idle timeout.
    failed_jobs.append({'job_name': 'unknown', 'status': 'error', 'error': error})

def cleanup():
    """Tear down the progress bar, then close the worker pool and wait for it."""
    # Progress bar first: clear the rendered line, then close the bar.
    for finish in (pbar.clear, pbar.close):
        finish()
    # Stop accepting work, then block until the workers have exited.
    for finish in (pool.close, pool.join):
        finish()

# Resolve which workspace to query: prefer the ID from config.json unless it
# is missing or the user opted out, in which case LAW_ID must be filled in.
workspace_id = read_config_values('config.json')

if workspace_id is None or not USE_DEFAULT_LAW_ID:
    if LAW_ID != "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx":
        law_id = LAW_ID
    else:
        raise Exception("Please specify a valid Log Analyics workspace ID in the Parameters cell.")
else:
    law_id = workspace_id

# Managed queue through which worker processes send progress updates.
manager = Manager()
queue = manager.Queue()

credential = DefaultAzureCredential()
client = LogsQueryClient(credential)

# Size the pool from THREADS so exactly the configured number of jobs run in
# parallel; Pool() with no argument would use the machine's CPU count instead.
pool = Pool(processes=THREADS)

# Each job contributes 100 progress units.
pbar = tqdm(total=THREADS * 100, leave=True, position=0)
pbar.set_description(f"Splitting query into {THREADS} jobs for parallel processing", refresh=True)

for job_range in ranges:
    pool.apply_async(
        export_log_analytics_data,
        kwds={
            'law_id': law_id,
            'query': QUERY,
            'start_time': job_range.start_time,
            'end_time': job_range.end_time,
            'batch_size': BATCH_SIZE,
            'job_name': job_range.name,
            'queue': queue,
            'min_batch_size': MIN_BATCH_SIZE,
            'client': client,
            'export_path': job_directory,
            'export_prefix': OUTPUT_FILE_PREFIX,
            'auto_batch': AUTO_BATCH,
            'export_format': OUTPUT_FORMAT,
            'timeout': TIMEOUT,
        },
        callback=log_result,
        error_callback=log_error,
    )

# Monitor until every job has reported a result and the queue is drained.
while (len(completed_jobs) + len(failed_jobs)) < THREADS or not queue.empty():
    # Drain everything that is waiting; handling a single item per one-second
    # tick makes the progress bar lag behind jobs that report often.
    while not queue.empty():
        item = queue.get()
        last_queue_time = datetime.now()
        pbar.update(item['progress_update'])
    if datetime.now() - last_queue_time > timedelta(minutes=TIMEOUT):
        pbar.set_description(f"No input received from running job(s) for more than {TIMEOUT} minutes, check logs for errors. Exiting.", refresh=True)
        if len(completed_jobs) > 0: print(f"Completed jobs: {', '.join([item['job_name'] for item in completed_jobs])}.")
        cleanup()
        break
    time.sleep(1)

def concat(csv_file):
    """Read one exported CSV file and return its contents as a DataFrame."""
    # low_memory=False makes pandas parse the file in a single pass.
    frame = pd.read_csv(csv_file, low_memory=False)
    return frame

# Report the outcome and, if requested, merge the per-batch files into one.
if completed_jobs:
    total_rows = sum(item['rows_returned_total'] for item in completed_jobs)
    pbar.set_description(f"Export of {total_rows} records to {job_directory} complete", refresh=True)
    if OUTPUT_COMBINE_FILES:
        # Combine in the format the batches were written in; looking only for
        # CSV files makes pd.concat fail on an empty list for PARQUET exports.
        if str(OUTPUT_FORMAT).upper() == 'PARQUET':
            data_files = sorted(glob.glob(job_directory + '/*.parquet'))
            if data_files:
                df = pd.concat(map(pd.read_parquet, data_files), ignore_index=True)
                df.to_parquet(job_directory + '/' + OUTPUT_FILE_PREFIX + '_FullExport.parquet')
        else:
            data_files = sorted(glob.glob(job_directory + '/*.{}'.format('csv')))
            if data_files:
                df = pd.concat(map(concat, data_files), ignore_index=True)
                df.to_csv(job_directory + '/' + OUTPUT_FILE_PREFIX + '_FullExport.csv', index=False)
        if not data_files:
            print(f"No data files found in {job_directory} to combine.")
else:
    pbar.set_description(f"No jobs completed successfully. Please check log files in {job_directory} for details.")

cleanup()

## 4. Cleanup
Run the below cell to delete all run data including logs and data files.

In [None]:
import shutil

# Remove the whole output directory tree: data files and logs of every run.
try:
    shutil.rmtree(OUTPUT_DIRECTORY)
except Exception as err:
    print(f"Error deleting data: {err}")
else:
    print("Data has been deleted.")
